package tree

import (
	"log"
	"fmt"
)

// Node colors stored in node.color.
const (
	// BLACK is the color of black nodes; NewRBTree also gives it to the sentinel.
	BLACK int = iota
	// RED is the color RBInsert assigns to every freshly created node.
	RED
)

// node is one element of the red-black tree.
type node struct {
	// left and right are the two children; parent is the node above.
	left   *node
	right  *node
	parent *node
	// val is the integer key the tree is ordered by.
	val int
	// color holds BLACK or RED.
	color int
}

// grandparent returns the parent of n's parent.
// It dereferences n.parent without a nil check.
func (n *node) grandparent() *node {
	p := n.parent
	return p.parent
}

// uncle returns the other child of n's grandparent, i.e. the sibling of
// n's parent. It dereferences n.parent and the grandparent without nil checks.
func (n *node) uncle() *node {
	gp := n.grandparent()
	if n.parent == gp.left {
		return gp.right
	}
	return gp.left
}

// sibling returns the other child of n's parent: the right child when n is
// the parent's left child, otherwise the left child.
func (n *node) sibling() *node {
	p := n.parent
	if p.left == n {
		return p.right
	}
	return p.left
}

// getSmallestChild returns the node holding the minimum value in the
// subtree rooted at n, found by following left links until the sentinel.
// It returns tree.Nil when n itself is tree.Nil.
func (tree *RBTree) getSmallestChild(n *node) *node {
	smallest := tree.Nil
	for cur := n; cur != tree.Nil; cur = cur.left {
		smallest = cur
	}
	return smallest
}

// rotateLeft lifts n over its parent. Before the rotation n is expected to
// be the right child of its parent; afterwards the old parent is n's left
// child and n occupies the old parent's place under the grandparent.
// If n has no parent (its parent is tree.Nil) it is simply recorded as root.
func (tree *RBTree) rotateLeft(n *node) {
	p := n.parent
	if p == tree.Nil {
		tree.root = n
		return
	}
	g := p.parent
	inner := n.left

	// n's former left subtree becomes the right subtree of the old parent.
	// The sentinel's parent pointer is deliberately left untouched.
	p.right = inner
	if inner != tree.Nil {
		inner.parent = p
	}

	// The old parent hangs below n; n takes over the grandparent link.
	n.left = p
	p.parent = n
	n.parent = g

	if tree.root == p {
		tree.root = n
	}
	if g == tree.Nil {
		return
	}
	// Re-point whichever grandparent slot used to reference the old parent.
	if g.left == p {
		g.left = n
	} else {
		g.right = n
	}
}

// rotateRight lifts n over its parent. Before the rotation n is expected to
// be the left child of its parent; afterwards the old parent is n's right
// child and n occupies the old parent's place under the grandparent.
//
// If n has no parent (its parent is tree.Nil) there is nothing to rotate
// over, so n is recorded as root and the call returns, mirroring rotateLeft.
// Without this guard the code below would write into the shared sentinel's
// left and parent fields.
func (tree *RBTree) rotateRight(n *node) {
	parent := n.parent
	if parent == tree.Nil {
		tree.root = n
		return
	}
	gp := parent.parent

	// n's former right subtree becomes the left subtree of the old parent.
	// The sentinel's parent pointer is deliberately left untouched.
	parent.left = n.right
	if n.right != tree.Nil {
		n.right.parent = parent
	}
	n.right = parent
	n.parent = gp
	parent.parent = n
	// Replace the root if the old parent was the root.
	if parent == tree.root {
		tree.root = n
	}
	// Update the grandparent's child pointer that referenced the old parent.
	if gp != tree.Nil {
		if parent == gp.left {
			gp.left = n
		} else {
			gp.right = n
		}
	}
}

// RBTree is a red-black tree of int values built on a shared sentinel node.
type RBTree struct {
	// root is the top node of the tree, or Nil when the tree is empty.
	root *node
	// Nil is the sentinel used in place of absent children and as the
	// parent of the root; NewRBTree creates it colored BLACK.
	Nil *node
}

// NewRBTree returns an empty tree. A single BLACK sentinel node is created
// and used both as tree.Nil and as the initial root; the sentinel's parent
// points to itself, while its left and right fields are left unset.
func NewRBTree() *RBTree {
	sentinel := &node{color: BLACK}
	sentinel.parent = sentinel
	return &RBTree{
		root: sentinel,
		Nil:  sentinel,
	}
}

// RBSearch walks down from the root comparing val with each node's value
// and returns the first node whose value equals val. It returns tree.Nil
// when no such node is reached.
func (tree *RBTree) RBSearch(val int) *node {
	cur := tree.root
	for cur != tree.Nil {
		switch {
		case val < cur.val:
			cur = cur.left
		case val > cur.val:
			cur = cur.right
		default:
			return cur
		}
	}
	return tree.Nil
}

// RBInsert adds val to the tree as a new RED node and then runs the
// insertion fix-up chain starting at insert_case1.
//
// Duplicate values are allowed: a value equal to an existing node's value
// descends into (and is attached to) that node's left side.
func (tree *RBTree) RBInsert(val int) {
	fresh := &node{
		val:   val,
		color: RED, // a red node does not change any path's black count
		left:  tree.Nil,
		right: tree.Nil,
	}

	// Walk down to the empty slot where the new node belongs.
	parent, cur := tree.Nil, tree.root
	for cur != tree.Nil {
		parent = cur
		if val > cur.val {
			cur = cur.right
		} else {
			cur = cur.left
		}
	}

	// Attach using the same comparison as the descent above. The previous
	// code attached equal keys on the right although the descent had gone
	// left, which overwrote (and lost) an existing right subtree.
	fresh.parent = parent
	switch {
	case parent == tree.Nil:
		tree.root = fresh
	case val > parent.val:
		parent.right = fresh
	default:
		parent.left = fresh
	}

	// Restore the red-black properties.
	tree.insert_case1(fresh)
}

// insert_case1 handles the case where n is the root: a node whose parent is
// tree.Nil is painted BLACK. Any other node is passed on to insert_case2.
func (tree *RBTree) insert_case1(n *node) {
	if n.parent != tree.Nil {
		tree.insert_case2(n)
		return
	}
	// n is the root node.
	n.color = BLACK
}

// insert_case2 stops the fix-up when n's parent is BLACK, because the
// red-black properties still hold. With a non-BLACK parent the red node n
// cannot stay as it is, so insert_case3 takes over.
func (tree *RBTree) insert_case2(n *node) {
	if n.parent.color != BLACK {
		tree.insert_case3(n)
	}
}

// insert_case3 handles a RED uncle: parent and uncle are painted BLACK, the
// grandparent RED so that no two red nodes are adjacent, and the fix-up
// restarts from the grandparent. When the uncle is tree.Nil or not RED the
// node is handed to insert_case4.
func (tree *RBTree) insert_case3(n *node) {
	u := n.uncle()
	if u == tree.Nil || u.color != RED {
		tree.insert_case4(n)
		return
	}
	gp := n.grandparent()
	n.parent.color = BLACK
	u.color = BLACK
	gp.color = RED
	// Recoloring may have broken the properties higher up; adjust recursively.
	tree.insert_case1(gp)
}

// insert_case4 runs when the uncle is black or absent. If n, its parent and
// its grandparent form a zig-zag, n is rotated over its parent so the three
// nodes line up, and the fix-up continues from the former parent (now n's
// child). In every case insert_case5 is called afterwards.
func (tree *RBTree) insert_case4(n *node) {
	p := n.parent
	gp := p.parent
	switch {
	case n == p.right && p == gp.left:
		// n is a right child and its parent a left child ("<" shape):
		// rotate so the three generations form a "/" line.
		tree.rotateLeft(n)
		n = n.left
	case n == p.left && p == gp.right:
		// Mirror image (">" shape): rotate so they form a "\" line.
		tree.rotateRight(n)
		n = n.right
	}
	tree.insert_case5(n)
}

// insert_case5 paints n's parent BLACK and its grandparent RED, then rotates
// the parent over the grandparent so the parent takes the grandparent's
// place and the grandparent becomes n's sibling.
func (tree *RBTree) insert_case5(n *node) {
	p := n.parent
	gp := p.parent
	p.color = BLACK
	gp.color = RED
	if n == p.left && p == gp.left {
		tree.rotateRight(p)
		return
	}
	tree.rotateLeft(p)
}

// RBDelete removes one node holding val. It returns false when RBSearch
// finds no such node and true after a node has been removed.
//
// A node without a right child is removed directly. Otherwise its value is
// swapped with the smallest node of its right subtree and that node is
// removed instead.
func (tree *RBTree) RBDelete(val int) bool {
	target := tree.RBSearch(val)
	if target == tree.Nil {
		return false
	}
	if target.right == tree.Nil {
		tree.delete_one_child(target)
		return true
	}
	// Smallest node of the right subtree; its left child is tree.Nil.
	succ := tree.getSmallestChild(target.right)
	DPrintf("Find the replacer (%v:%v)", succ.color, succ.val)
	// Move the replacement value up into the target's position.
	target.val, succ.val = succ.val, target.val
	tree.delete_one_child(succ)
	return true
}

// delete_one_child unlinks n, which is expected to have at most one
// non-sentinel child, and rebalances when a black node was removed.
//
// The replacement child is n.left unless that is tree.Nil, in which case it
// is n.right (possibly tree.Nil itself). Note that the child's parent pointer
// is always rewritten, so the sentinel's parent may be set here.
func (tree *RBTree) delete_one_child(n *node) {
	DPrintf("Delete the node %v:%v with one child", n.color, n.val)
	child := n.left
	if child == tree.Nil {
		child = n.right
	}

	if n.parent == tree.Nil {
		// n is the root.
		if n.left == tree.Nil && n.right == tree.Nil {
			// n was the only element: the tree becomes empty.
			DPrintf("delete the last element in the tree")
			tree.root = tree.Nil
			return
		}
		DPrintf("adjust root")
		child.parent = tree.Nil
		tree.root = child
		// The root must be black; the removed root is compensated by
		// painting its replacement BLACK.
		child.color = BLACK
		return
	}

	// Splice n out from between its parent and its child.
	if n == n.parent.left {
		n.parent.left = child
	} else {
		n.parent.right = child
	}
	child.parent = n.parent
	DPrintf("After remove the node")
	if DEBUG > 1 {
		tree.LevelOrderTraversalPrint()
	}
	DPrintf("Begin adjusting ")

	// Adjust around child so the red-black properties hold without n.
	switch {
	case n.color != BLACK:
		// Removing a red node leaves the black count of every path unchanged.
		DPrintf("n is RED, no adjust")
	case child.color == RED:
		// Repaint the child to make up for the removed black node.
		child.color = BLACK
		DPrintf("Finish adjust, change the n.child.color to BLACK")
	default:
		// Both n and child are black: run the deletion fix-up cases.
		DPrintf("Begin adjust child , Nil?%v", child == tree.Nil)
		tree.delete_case1(child)
	}
}

// delete_case1 starts the deletion fix-up for n. When n has a parent the
// work is delegated to delete_case2; when n's parent is tree.Nil, n is the
// new root and nothing needs to be done.
//
// NOTE: the log line below is emitted on every call, including when n is not
// the root and delete_case2 has just run.
func (tree *RBTree) delete_case1(n *node) {
	hasParent := n.parent != tree.Nil
	if hasParent {
		tree.delete_case2(n)
	}
	DPrintf("delete case 1 finish adjust, the node(%v:%v) is new root", n.color, n.val)
}

// delete_case2 handles a RED sibling: the parent is painted RED, the sibling
// BLACK, and the sibling is rotated over the parent so that it becomes the
// parent's parent. delete_case3 always runs afterwards.
func (tree *RBTree) delete_case2(n *node) {
	if s := n.sibling(); s.color == RED {
		p := n.parent
		p.color = RED
		s.color = BLACK
		// Make the sibling the parent of n's parent.
		if p.left == n {
			tree.rotateLeft(s)
		} else {
			tree.rotateRight(s)
		}
		DPrintf("after delete case2 adjust")
		if DEBUG > 1 {
			tree.LevelOrderTraversalPrint()
		}
	}
	// n's sibling is now one of the old sibling's children, but the paths on
	// the two sides still differ by one black node, so keep adjusting.
	tree.delete_case3(n)
}

// delete_case3 handles the all-black configuration: the sibling, the parent
// and both of the sibling's children are BLACK. The sibling is painted RED
// and the fix-up restarts from the parent. Any other configuration is
// forwarded to delete_case4.
func (tree *RBTree) delete_case3(n *node) {
	s := n.sibling()
	allBlack := s.color == BLACK && n.parent.color == BLACK &&
		s.left.color == BLACK && s.right.color == BLACK
	if !allBlack {
		tree.delete_case4(n)
		return
	}
	// Take one black node off the sibling's side. Paths through n.parent now
	// agree with each other, but every path through n.parent has one black
	// node fewer than paths through the parent's sibling, so the parent has
	// to be adjusted recursively.
	s.color = RED
	DPrintf("after delete adjust 3")
	if DEBUG > 1 {
		tree.LevelOrderTraversalPrint()
	}
	DPrintf("recurrently adjust n.parent")
	tree.delete_case1(n.parent)
}

// delete_case4 handles a RED parent whose other child (the sibling) and both
// of the sibling's children are BLACK: parent and sibling swap colors and the
// fix-up ends. Any other configuration is forwarded to delete_case5.
func (tree *RBTree) delete_case4(n *node) {
	s := n.sibling()
	swapColors := n.parent.color == RED && s.color == BLACK &&
		s.left.color == BLACK && s.right.color == BLACK
	if !swapColors {
		tree.delete_case5(n)
		return
	}
	n.parent.color = BLACK
	s.color = RED
	DPrintf("Finish delete case 4")
}

// delete_case5 prepares for delete_case6 when the sibling is BLACK and its
// only RED child is on the side nearer to n: the sibling is painted RED, that
// child BLACK, and the child is rotated over the sibling. delete_case6 always
// runs afterwards.
func (tree *RBTree) delete_case5(n *node) {
	s := n.sibling()
	// Simplified condition: the original author notes that after
	// delete_case2 the sibling must be black at this point.
	if s.color == BLACK {
		p := n.parent
		switch {
		case n == p.left && s.right.color == BLACK && s.left.color == RED:
			s.color = RED
			s.left.color = BLACK
			tree.rotateRight(s.left)
		case n == p.right && s.right.color == RED && s.left.color == BLACK:
			s.color = RED
			s.right.color = BLACK
			tree.rotateLeft(s.right)
		}
		DPrintf("after delete case 5")
		if DEBUG > 1 {
			tree.LevelOrderTraversalPrint()
		}
	}
	tree.delete_case6(n)
}

// delete_case6 is the final deletion fix-up step: the sibling takes over the
// parent's color, the parent is painted BLACK, the sibling's child on the
// side away from n is painted BLACK, and the sibling is rotated over the
// parent.
//
// TODO: the original author left an unspecified "todo" marker on this function.
func (tree *RBTree) delete_case6(n *node) {
	p := n.parent
	s := n.sibling()
	s.color = p.color
	p.color = BLACK

	switch n {
	case p.left:
		s.right.color = BLACK
		tree.rotateLeft(s)
	default:
		s.left.color = BLACK
		tree.rotateRight(s)
	}
	DPrintf("finish detet case 6")
}

// LevelOrderTraversalPrint writes the tree to standard output one level per
// line. Each node is printed as "color:val " and each sentinel position as
// "Nil ". An empty tree prints a single "Nil" line.
func (tree *RBTree) LevelOrderTraversalPrint() {
	if tree.root == tree.Nil {
		fmt.Printf("Nil\n")
		return
	}
	level := []*node{tree.root}
	for len(level) > 0 {
		// Collect the children of this level while printing it.
		var next []*node
		for _, cur := range level {
			if cur == tree.Nil {
				fmt.Printf("Nil ")
				continue
			}
			fmt.Printf("%v:%d ", cur.color, cur.val)
			next = append(next, cur.left, cur.right)
		}
		level = next
		fmt.Println("")
	}
}

// DEBUG is the debug verbosity: DPrintf logs when it is greater than 0, and
// other code in this file prints the whole tree when it is greater than 1.
const DEBUG int = 1

// DPrintf forwards format and a to log.Printf when DEBUG is greater than 0.
// The results are always 0 and nil.
func DPrintf(format string, a ...interface{}) (n int, err error) {
	if DEBUG <= 0 {
		return 0, nil
	}
	log.Printf(format, a...)
	return 0, nil
}

// Other helper functions.
// GetSuccessor returns the node that follows val in ascending order.
// It returns tree.Nil when val is not in the tree or has no successor.
func (tree *RBTree) GetSuccessor(val int) *node {
	cur := tree.RBSearch(val)
	if cur == tree.Nil {
		return cur
	}
	if cur.right != tree.Nil {
		return tree.getSmallestChild(cur.right)
	}
	// Climb while we arrive from a right child; the first ancestor reached
	// from its left side is the successor.
	anc := cur.parent
	for anc != tree.Nil && cur == anc.right {
		cur, anc = anc, anc.parent
	}
	return anc
}

// GetPredecessor returns the node that precedes val in ascending order.
// It returns tree.Nil when val is not in the tree or has no predecessor.
func (tree *RBTree) GetPredecessor(val int) *node {
	cur := tree.RBSearch(val)
	if cur == tree.Nil {
		return cur
	}
	if cur.left != tree.Nil {
		return tree.getBiggestChild(cur.left)
	}
	// Climb while we arrive from a left child. The climb stops at tree.Nil or
	// at the first turning point, an ancestor reached from its right side.
	anc := cur.parent
	for anc != tree.Nil && cur == anc.left {
		cur, anc = anc, anc.parent
	}
	return anc
}

// getBiggestChild returns the node holding the maximum value in the subtree
// rooted at n, found by following right links until the sentinel.
// It returns tree.Nil when n itself is tree.Nil.
func (tree *RBTree) getBiggestChild(n *node) *node {
	biggest := tree.Nil
	for cur := n; cur != tree.Nil; cur = cur.right {
		biggest = cur
	}
	return biggest
}

// GetNil returns the tree's sentinel node, which search and neighbor
// queries return to signal "not found".
func (tree *RBTree) GetNil() *node {
	return tree.Nil
}

// GetVal returns the integer value stored in n.
func (n *node) GetVal() int {
	return n.val
}